#!/usr/bin/env python
# -*- encoding: utf-8 -*-
'''
:File       :fc_net.py
:Description:
    Fully connected network (全连接网络).
    Extracted from https://gitee.com/kongfanhe/pytorch-tutorial/blob/master/test.py
:EditTime   :2025/05/03 19:30:58
:Author     :Kiumb
'''

import torch

class Net(torch.nn.Module):
    """Fully connected classifier: three ReLU hidden layers and a log-softmax head.

    With the default sizes it maps a flattened 28x28 input (784 features)
    through three 64-unit hidden layers to 10 class scores.

    Args:
        in_features: Number of input features per sample (default ``28 * 28``).
        hidden_features: Width of each of the three hidden layers (default 64).
        num_classes: Number of output classes (default 10).
    """

    def __init__(
        self,
        in_features: int = 28 * 28,
        hidden_features: int = 64,
        num_classes: int = 10,
    ):
        super().__init__()
        # Attribute names fc1..fc4 are kept so existing state_dict checkpoints
        # still load with the default sizes.
        self.fc1 = torch.nn.Linear(in_features, hidden_features)
        self.fc2 = torch.nn.Linear(hidden_features, hidden_features)
        self.fc3 = torch.nn.Linear(hidden_features, hidden_features)
        self.fc4 = torch.nn.Linear(hidden_features, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return per-class log-probabilities for a batch.

        Args:
            x: Tensor whose first dimension is the batch. All remaining
                dimensions are flattened, so with default sizes both
                ``(N, 784)`` and ``(N, 1, 28, 28)`` inputs are accepted.

        Returns:
            Tensor of shape ``(N, num_classes)`` holding log-softmax scores
            over dim 1 (suitable for ``torch.nn.NLLLoss``).
        """
        # Flatten everything after the batch dim; a no-op for already-flat
        # (N, F) input, and it prevents unflattened image batches from
        # producing shape errors or a degenerate softmax over a size-1 dim.
        x = torch.flatten(x, start_dim=1)
        x = torch.nn.functional.relu(self.fc1(x))
        x = torch.nn.functional.relu(self.fc2(x))
        x = torch.nn.functional.relu(self.fc3(x))
        return torch.nn.functional.log_softmax(self.fc4(x), dim=1)